{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(73244:281473741362736,MainProcess):2021-03-02-10:56:44.683.806 [mindspore/_check_version.py:207] MindSpore version 1.1.1 and \"te\" wheel package version 1.0 does not match, reference to the match info on: https://www.mindspore.cn/install\n",
      "MindSpore version 1.1.1 and \"topi\" wheel package version 0.6.0 does not match, reference to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(73244:281473741362736,MainProcess):2021-03-02-10:56:45.240.993 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import mindspore\n",
    "import mindspore.nn as nn\n",
    "from mindspore import context\n",
    "from mindspore.train.model import Model\n",
    "from mindspore.nn.metrics import Accuracy\n",
    "from mindspore.train.serialization import load_checkpoint, load_param_into_net\n",
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor\n",
    "\n",
    "from src.config import cfg\n",
    "from src.textcnn import TextCNN\n",
    "from src.dataset import MovieReview"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select graph-mode execution on the device named in the project config.\n",
    "context.set_context(\n",
    "    mode=context.GRAPH_MODE,\n",
    "    device_target=cfg.device_target,\n",
    "    device_id=cfg.device_id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the raw data; `split=0.9` is presumably the train fraction (verify in src.dataset).\n",
    "instance = MovieReview(\n",
    "    root_dir=cfg.data_path,\n",
    "    maxlen=cfg.word_len,\n",
    "    split=0.9,\n",
    ")\n",
    "\n",
    "# Batched training dataset built from that wrapper.\n",
    "dataset = instance.create_train_dataset(\n",
    "    batch_size=cfg.batch_size,\n",
    "    epoch_size=cfg.epoch_size,\n",
    ")\n",
    "\n",
    "# Dataset size as reported by the dataset; reused to size the LR schedule and the callbacks.\n",
    "batch_num = dataset.get_dataset_size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning-rate schedule: warm-up, constant plateau, then decay.\n",
    "# The list is built in runs of `batch_num` entries (one run per epoch), so the epoch\n",
    "# index must be the OUTER loop of each comprehension. With the step loop outermost the\n",
    "# per-epoch rates were interleaved step by step instead of being held for a whole epoch.\n",
    "base_lr = 1e-3\n",
    "warm_up_epochs = math.floor(cfg.epoch_size / 5)\n",
    "shrink_epochs = math.floor(cfg.epoch_size * 3 / 5)\n",
    "# NOTE(review): the plateau budget subtracts floor(2/5 * epochs) while the decay phase\n",
    "# spans floor(3/5 * epochs), so the list can be longer than epoch_size * batch_num.\n",
    "# Segment lengths are kept as they were; confirm which fraction is intended.\n",
    "normal_epochs = cfg.epoch_size - warm_up_epochs - math.floor(cfg.epoch_size * 2 / 5)\n",
    "\n",
    "# Linear ramp up to base_lr; empty (and never divides by zero) when warm_up_epochs == 0.\n",
    "warm_up = [base_lr / warm_up_epochs * (i + 1)\n",
    "           for i in range(warm_up_epochs) for _ in range(batch_num)]\n",
    "# Constant plateau at base_lr.\n",
    "normal_run = [base_lr for _ in range(normal_epochs) for _ in range(batch_num)]\n",
    "# Decay: base_lr / 16, base_lr / 32, base_lr / 48, ... one value per epoch.\n",
    "shrink = [base_lr / (16 * (i + 1))\n",
    "          for i in range(shrink_epochs) for _ in range(batch_num)]\n",
    "\n",
    "learning_rate = warm_up + normal_run + shrink"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the TextCNN network; the vocabulary size comes from the dataset wrapper,\n",
    "# the remaining dimensions from the project config.\n",
    "net = TextCNN(\n",
    "    vocab_len=instance.get_dict_len(),\n",
    "    word_len=cfg.word_len,\n",
    "    num_classes=cfg.num_classes,\n",
    "    vec_length=cfg.vec_length,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally resume from an existing checkpoint: when `cfg.pre_trained` is truthy,\n",
    "# the parameters stored at `cfg.checkpoint_path` are loaded into `net`.\n",
    "if cfg.pre_trained:\n",
    "    param_dict = load_checkpoint(cfg.checkpoint_path)\n",
    "    load_param_into_net(net, param_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over the parameters that require gradients, driven by the schedule built above.\n",
    "opt = nn.Adam(\n",
    "    [param for param in net.get_parameters() if param.requires_grad],\n",
    "    learning_rate=learning_rate,\n",
    "    weight_decay=cfg.weight_decay,\n",
    ")\n",
    "\n",
    "# Softmax cross-entropy configured with sparse=True.\n",
    "loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle network, loss and optimizer into a Model; accuracy is registered under the key 'acc'.\n",
    "model = Model(\n",
    "    net,\n",
    "    loss_fn=loss,\n",
    "    optimizer=opt,\n",
    "    metrics={'acc': Accuracy()},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monitoring callbacks for timing and loss.\n",
    "time_cb = TimeMonitor(data_size=batch_num)\n",
    "loss_cb = LossMonitor()\n",
    "\n",
    "# Checkpointing: the save interval is half of the total step count\n",
    "# (epoch_size * batch_num / 2), written under ./ckpt with the prefix 'train_textcnn'.\n",
    "ckpt_save_dir = \"./ckpt\"\n",
    "config_ck = CheckpointConfig(\n",
    "    save_checkpoint_steps=int(cfg.epoch_size * batch_num / 2),\n",
    "    keep_checkpoint_max=cfg.keep_checkpoint_max,\n",
    ")\n",
    "ckpoint_cb = ModelCheckpoint(\n",
    "    prefix=\"train_textcnn\",\n",
    "    directory=ckpt_save_dir,\n",
    "    config=config_ck,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 596, loss is 0.04297541\n",
      "epoch time: 52296.017 ms, per step time: 87.745 ms\n",
      "epoch: 2 step: 596, loss is 0.0065871133\n",
      "epoch time: 4298.849 ms, per step time: 7.213 ms\n",
      "epoch: 3 step: 596, loss is 0.0002644311\n",
      "epoch time: 4260.524 ms, per step time: 7.149 ms\n",
      "epoch: 4 step: 596, loss is 0.0017103986\n",
      "epoch time: 4296.318 ms, per step time: 7.209 ms\n",
      "train success\n"
     ]
    }
   ],
   "source": [
    "# Train for cfg.epoch_size epochs with the timing, checkpoint and loss callbacks attached.\n",
    "train_callbacks = [time_cb, ckpoint_cb, loss_cb]\n",
    "model.train(cfg.epoch_size, dataset, callbacks=train_callbacks)\n",
    "print(\"train success\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation\n",
    "# NOTE(review): this creates a second MovieReview instance and rebuilds the training\n",
    "# dataset; a later cell replaces `dataset` with the test dataset. If MovieReview splits\n",
    "# the data randomly, this split may differ from the one used for training - confirm.\n",
    "instance = MovieReview(\n",
    "    root_dir=cfg.data_path,\n",
    "    maxlen=cfg.word_len,\n",
    "    split=0.9,\n",
    ")\n",
    "dataset = instance.create_train_dataset(\n",
    "    batch_size=cfg.batch_size,\n",
    "    epoch_size=cfg.epoch_size,\n",
    ")\n",
    "batch_num = dataset.get_dataset_size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint file to evaluate. The name is hard-coded to the file this training run\n",
    "# produced under ./ckpt; update it if the epoch or step counts change.\n",
    "checkpoint_path = \"./ckpt/train_textcnn-4_596.ckpt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load checkpoint from [./ckpt/train_textcnn-4_596.ckpt].\n",
      "accuracy:  {'acc': 0.76171875}\n"
     ]
    }
   ],
   "source": [
    "# Evaluate a saved checkpoint on the test dataset.\n",
    "dataset = instance.create_test_dataset(batch_size=cfg.batch_size)\n",
    "\n",
    "# Build the fresh network FIRST so that the optimizer below is bound to the parameters\n",
    "# of the network being evaluated rather than to the previously trained `net` object.\n",
    "net = TextCNN(vocab_len=instance.get_dict_len(), word_len=cfg.word_len,\n",
    "              num_classes=cfg.num_classes, vec_length=cfg.vec_length)\n",
    "opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()),\n",
    "              learning_rate=0.001, weight_decay=cfg.weight_decay)\n",
    "loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)\n",
    "\n",
    "# Prefer the notebook-level checkpoint path; fall back to the configured one when it is None.\n",
    "eval_ckpt_path = checkpoint_path if checkpoint_path is not None else cfg.checkpoint_path\n",
    "param_dict = load_checkpoint(eval_ckpt_path)\n",
    "print(\"load checkpoint from [{}].\".format(eval_ckpt_path))\n",
    "\n",
    "# Restore the weights and switch the network out of training mode before evaluating.\n",
    "load_param_into_net(net, param_dict)\n",
    "net.set_train(False)\n",
    "model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc': Accuracy()})\n",
    "\n",
    "acc = model.eval(dataset)\n",
    "print(\"accuracy: \", acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
